"""
Probe: Is there enough historical data on threshold crossings for a paper?

Questions:
1. Which countries crossed from <$9k to >$25k? When?
2. Which entered the zone but got stuck?
3. What predicts successful crossing?
4. Where are current threshold countries?
"""

import pandas as pd
import numpy as np

# Full country-year panel; `panel` is the historical part (years up to 2024).
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full[full['year'] <= 2024].copy()

# Resource rents are optional: merge them when the file is usable, otherwise
# continue without the column. Only load/merge failures are tolerated
# (missing or unreadable file -> OSError, missing columns -> KeyError,
# empty/malformed CSV or incompatible merge keys -> ValueError), so that
# interrupts and genuine programming errors are no longer silently swallowed.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, ValueError) as exc:
    print(f"Note: resource rents not merged ({type(exc).__name__}: {exc})")

# ═══════════════════════════════════════════════════════════════════════
# 1. Historical trajectory: classify countries by threshold crossing
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("1. HISTORICAL THRESHOLD CROSSING ANALYSIS")
print("=" * 70)

# Per-country summary of the observed GDP-per-capita trajectory.
# 'first' and 'last' are positional within each group, so the rows are
# ordered by year before grouping; otherwise first_gdp/last_gdp would depend
# on whatever row order the CSV (and the merge) happened to produce.
countries = (
    panel[panel['gdp_pc_ppp'].notna()]
    .sort_values(['iso3', 'year'])
    .groupby('iso3')
    .agg(
        first_year=('year', 'min'),
        last_year=('year', 'max'),
        first_gdp=('gdp_pc_ppp', 'first'),
        last_gdp=('gdp_pc_ppp', 'last'),
        min_gdp=('gdp_pc_ppp', 'min'),
        max_gdp=('gdp_pc_ppp', 'max'),
        n_obs=('gdp_pc_ppp', 'count'),
    )
)

# Classify
def classify(row):
    """Label a trajectory from its first and last observed GDP per capita.

    Each of ``row['first_gdp']`` and ``row['last_gdp']`` is placed in a band:
    below 9000, inside the closed interval [9000, 25000], or above 25000.
    The (start band, end band) pair selects the label. A start above 25000 is
    always reported as 'Always above', whatever the end value is. Any pair
    that cannot be placed (e.g. a NaN value) yields 'Other'.
    """
    def band(value):
        if value < 9000:
            return 'below'
        if value <= 25000:
            return 'zone'
        if value > 25000:
            return 'above'
        return None  # NaN fails every comparison

    start = band(row['first_gdp'])
    if start == 'above':
        # The end value is deliberately not consulted for this label.
        return 'Always above'

    labels = {
        ('below', 'above'): 'Crossed (below→above)',
        ('below', 'zone'): 'In zone (entered from below)',
        ('below', 'below'): 'Still below',
        ('zone', 'above'): 'Crossed (zone→above)',
        ('zone', 'zone'): 'Stuck in zone',
        ('zone', 'below'): 'Fell back below',
    }
    return labels.get((start, band(row['last_gdp'])), 'Other')

# Attach the trajectory label to every country.
countries['status'] = countries.apply(classify, axis=1)

# Fixed reporting order for the labels.
status_order = (
    'Crossed (below→above)', 'Crossed (zone→above)', 'In zone (entered from below)',
    'Stuck in zone', 'Still below', 'Fell back below', 'Always above', 'Other',
)

print("\nClassification summary:")
for label in status_order:
    members = countries.loc[countries['status'] == label]
    n_members = len(members)
    if n_members == 0:
        continue
    print(f"\n  {label}: {n_members} countries")
    if n_members > 25:
        # Large groups are only counted, not listed.
        continue
    ranked = members.sort_values('last_gdp', ascending=False)
    for code, rec in ranked.iterrows():
        start = f"${rec['first_gdp']:,.0f} ({int(rec['first_year'])})"
        end = f"${rec['last_gdp']:,.0f} ({int(rec['last_year'])})"
        print(f"    {code}: {start} → {end}")

# ═══════════════════════════════════════════════════════════════════════
# 2. For countries that crossed: when did they cross $9k and $25k?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2. CROSSING DATES")
print("=" * 70)

# Countries whose status label starts with 'Crossed'.
crossed_all = countries[countries['status'].str.startswith('Crossed')].index.tolist()

# Variables recorded at the entry ($9k) and exit ($25k) years, when present.
entry_exit_vars = ['Z_1', 'old_dep', 'kaopen', 'ca_gdp', 'resource_rents_gdp', 'gross_savings_gdp']

crossing_data = []
for iso in crossed_all:
    c = panel[panel['iso3'] == iso].sort_values('year')
    gdp_series = c[['year', 'gdp_pc_ppp']].dropna()

    # First observed year at or above each threshold (NaN if never reached).
    # When a series already starts at or above $9k, cross_9k is simply the
    # first observed year.
    cross_9k = gdp_series.loc[gdp_series['gdp_pc_ppp'] >= 9000, 'year'].min()
    cross_25k = gdp_series.loc[gdp_series['gdp_pc_ppp'] >= 25000, 'year'].min()

    # Time in zone
    if pd.notna(cross_9k) and pd.notna(cross_25k):
        years_in_zone = cross_25k - cross_9k
    else:
        years_in_zone = np.nan

    # Demographics at entry and exit
    at_9k = c[c['year'] == cross_9k] if pd.notna(cross_9k) else pd.DataFrame()
    at_25k = c[c['year'] == cross_25k] if pd.notna(cross_25k) else pd.DataFrame()

    row = {'iso3': iso, 'cross_9k': cross_9k, 'cross_25k': cross_25k,
           'years_in_zone': years_in_zone}

    for var in entry_exit_vars:
        if var in c.columns:
            # A non-empty at_9k/at_25k is a slice of c, so it has the column.
            row[f'{var}_at_entry'] = at_9k[var].values[0] if len(at_9k) > 0 else np.nan
            row[f'{var}_at_exit'] = at_25k[var].values[0] if len(at_25k) > 0 else np.nan

    crossing_data.append(row)

# With no crossers the frame has no columns, so sorting by 'years_in_zone'
# would raise KeyError; sort only when there is something to sort.
cross_df = pd.DataFrame(crossing_data)
if not cross_df.empty:
    cross_df = cross_df.sort_values('years_in_zone')

print("\nCountries that crossed the full threshold ($9k → $25k):")
print(f"{'Country':6s} {'Enter $9k':>10s} {'Exit $25k':>10s} {'Years':>6s} {'Z₁ entry':>9s} {'Z₁ exit':>8s} {'OADR entry':>11s} {'Rents entry':>12s}")
print("-" * 80)
for _, row in cross_df.iterrows():
    # Missing values are shown as an em dash ('?' for missing years).
    z1e = f"{row.get('Z_1_at_entry', np.nan):.2f}" if pd.notna(row.get('Z_1_at_entry')) else '—'
    z1x = f"{row.get('Z_1_at_exit', np.nan):.2f}" if pd.notna(row.get('Z_1_at_exit')) else '—'
    oadr = f"{row.get('old_dep_at_entry', np.nan)*100:.1f}%" if pd.notna(row.get('old_dep_at_entry')) else '—'
    # Named rents_str so the module-level `rents` DataFrame is not overwritten.
    rents_str = f"{row.get('resource_rents_gdp_at_entry', np.nan):.1f}%" if pd.notna(row.get('resource_rents_gdp_at_entry')) else '—'
    yrs = f"{row['years_in_zone']:.0f}" if pd.notna(row['years_in_zone']) else '—'
    print(f"{row['iso3']:6s} {int(row['cross_9k']) if pd.notna(row['cross_9k']) else '?':>10} "
          f"{int(row['cross_25k']) if pd.notna(row['cross_25k']) else '?':>10} "
          f"{yrs:>6s} {z1e:>9s} {z1x:>8s} {oadr:>11s} {rents_str:>12s}")

# ═══════════════════════════════════════════════════════════════════════
# 3. Current threshold cohort ($9k-$25k)
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3. CURRENT THRESHOLD COHORT ($9k-$25k)")
print("=" * 70)

# Rows for 2020-2024, ordered by year so that the positional 'last'
# aggregation returns each country's most recent non-missing value.
recent = panel[(panel['year'] >= 2020) & (panel['year'] <= 2024)].sort_values(['iso3', 'year'])

agg_spec = {
    'gdp': ('gdp_pc_ppp', 'last'),
    'Z1': ('Z_1', 'last'),
    'oadr': ('old_dep', 'last'),
    'ca': ('ca_gdp', 'last'),
    'nfa': ('nfa_gdp', 'last'),
    'kaopen': ('kaopen', 'last'),
    'fiscal': ('fiscal_bal_gdp', 'last'),
}
# Savings and rents are optional columns. When one is absent it is reported
# as missing, rather than substituting a count of ca_gdp observations that
# would be printed under the Savings/Rents headers.
optional_cols = (('savings', 'gross_savings_gdp'), ('rents', 'resource_rents_gdp'))
for out_name, src_col in optional_cols:
    if src_col in recent.columns:
        agg_spec[out_name] = (src_col, 'last')

current = recent.groupby('iso3').agg(**agg_spec).dropna(subset=['gdp'])
for out_name, _ in optional_cols:
    if out_name not in current.columns:
        current[out_name] = np.nan

zone = current[(current['gdp'] >= 9000) & (current['gdp'] <= 25000)].sort_values('gdp', ascending=False)

# Add projections. The processed panel is already in memory as `full`, so it
# is reused instead of reading the same CSV a second time.
full_proj = full
proj_2050 = full_proj.loc[full_proj['year'] == 2050, ['iso3', 'Z_1', 'old_dep']].rename(
    columns={'Z_1': 'Z1_2050', 'old_dep': 'oadr_2050'})
# Left join on the iso3 index: a zone country without a 2050 row stays in the
# table with NaN projections instead of silently dropping out of the cohort.
zone = zone.join(proj_2050.set_index('iso3'), how='left')

print(f"\n{len(zone)} countries in threshold zone:")
print(f"{'Country':6s} {'GDP/cap':>8s} {'Z₁':>6s} {'Z₁_50':>7s} {'OADR':>6s} {'OADR_50':>8s} {'CA':>6s} {'Rents':>6s} {'KAOPEN':>7s} {'Savings':>8s}")
print("-" * 80)
for iso, row in zone.iterrows():
    # Optional or sparsely populated fields print as an em dash when missing.
    rents_str = f"{row['rents']:.1f}" if pd.notna(row.get('rents')) else '—'
    sav_str = f"{row['savings']:.1f}" if pd.notna(row.get('savings')) else '—'
    kao_str = f"{row['kaopen']:.1f}" if pd.notna(row.get('kaopen')) else '—'
    print(f"{iso:6s} {row['gdp']:8,.0f} {row['Z1']:6.2f} {row.get('Z1_2050', np.nan):7.2f} "
          f"{row['oadr']*100:5.1f}% {row.get('oadr_2050', np.nan)*100:7.1f}% "
          f"{row['ca']:6.1f} {rents_str:>6s} {kao_str:>7s} {sav_str:>8s}")

# ═══════════════════════════════════════════════════════════════════════
# 4. Quick regression: what predicts successful crossing?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("4. WHAT PREDICTS CROSSING? (Cross-sectional)")
print("=" * 70)

# Every country with at least one observation inside [9000, 25000].
ever_in_zone = panel[panel['gdp_pc_ppp'].between(9000, 25000)]
zone_countries = ever_in_zone['iso3'].unique()

# Characteristics copied from the entry-year row when the column exists.
entry_vars = ['Z_1', 'old_dep', 'youth_dep', 'working_age_share', 'kaopen',
              'ca_gdp', 'gross_savings_gdp', 'fiscal_bal_gdp', 'nfa_gdp',
              'resource_rents_gdp', 'rgdp_growth']

# One profile per country: the first in-zone year, what happened to GDP from
# that year onward, and the country's characteristics in that year.
zone_profiles = []
for iso in zone_countries:
    history = panel.loc[panel['iso3'] == iso].sort_values('year')
    observed = history.dropna(subset=['gdp_pc_ppp'])
    if observed.empty:
        continue

    zone_years = observed.loc[observed['gdp_pc_ppp'].between(9000, 25000), 'year']
    if zone_years.empty:
        continue
    entry_year = zone_years.min()

    # GDP path from the entry year onward decides the two outcome flags.
    later_gdp = observed.loc[observed['year'] >= entry_year, 'gdp_pc_ppp']

    at_entry = history.loc[history['year'] == entry_year]
    if at_entry.empty:
        continue

    profile = {
        'iso3': iso,
        'entry_year': entry_year,
        'exited_above': int((later_gdp > 25000).any()),
        'fell_below': int((later_gdp < 9000).any()),
        'gdp_at_entry': at_entry['gdp_pc_ppp'].iloc[0],
    }
    profile.update({name: at_entry[name].iloc[0]
                    for name in entry_vars if name in at_entry.columns})
    zone_profiles.append(profile)

zp = pd.DataFrame(zone_profiles)

# Split by outcome; exited_above is always 0 or 1.
did_cross = zp['exited_above'] == 1
crossers = zp[did_cross]
non_crossers = zp[~did_cross]

print(f"\nCountries that entered threshold zone: {len(zp)}")
print(f"  Exited above $25k: {len(crossers)}")
print(f"  Still in zone or fell back: {len(non_crossers)}")

compare_vars = ['Z_1', 'old_dep', 'youth_dep', 'working_age_share', 'kaopen',
                'ca_gdp', 'gross_savings_gdp', 'fiscal_bal_gdp',
                'resource_rents_gdp', 'rgdp_growth']

print(f"\n{'Variable':25s} {'Crossers':>10s} {'Non-crossers':>13s} {'Diff':>8s} {'p-value':>8s}")
print("-" * 70)
from scipy import stats
for var in compare_vars:
    if var not in zp.columns:
        continue
    c_vals = crossers[var].dropna()
    nc_vals = non_crossers[var].dropna()
    # Each group needs more than three observations to be compared.
    if len(c_vals) <= 3 or len(nc_vals) <= 3:
        continue

    # Two-sample t-test without the equal-variance assumption.
    _t_stat, p = stats.ttest_ind(c_vals, nc_vals, equal_var=False)
    mean_c = c_vals.mean()
    mean_nc = nc_vals.mean()
    diff = mean_c - mean_nc

    if p < 0.01:
        sig = '***'
    elif p < 0.05:
        sig = '**'
    elif p < 0.1:
        sig = '*'
    else:
        sig = ''
    print(f"  {var:23s} {mean_c:10.3f} {mean_nc:13.3f} {diff:8.3f} {p:7.4f}{sig}")

# ═══════════════════════════════════════════════════════════════════════
# 5. How long do countries spend in the zone?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5. TIME IN ZONE")
print("=" * 70)

# Span per country: last in-zone year minus first in-zone year. Any years
# spent outside the zone between those two dates are included in the span.
in_zone_rows = panel[panel['gdp_pc_ppp'].between(9000, 25000)]
in_zone_years = in_zone_rows.groupby('iso3')['year']
zone_span = in_zone_years.max() - in_zone_years.min()
zp['years_in_zone'] = zp['iso3'].map(zone_span).astype(float)

crossed_mask = zp['exited_above'] == 1
c_yrs = zp.loc[crossed_mask, 'years_in_zone'].dropna()
nc_yrs = zp.loc[zp['exited_above'] == 0, 'years_in_zone'].dropna()

print("\nYears spent in $9k-$25k zone:")
print(f"  Crossers: mean={c_yrs.mean():.1f}, median={c_yrs.median():.0f}")
print(f"  Non-crossers: mean={nc_yrs.mean():.1f}, median={nc_yrs.median():.0f}")

print("\n✓ Probe complete.")
